Collecting clu
Collecting ml-collections (from clu)
Successfully installed clu-0.0.12 ml-collections-1.1.0
import jax.numpy as jnp
import jax
import flax.linen as nn
import flax
import tensorflow_datasets as tfds
from functools import partial
import numpy as np
from flax.training import train_state # Useful dataclass to keep train state
from flax import struct # Flax dataclasses
import optax
from clu import metrics
from typing import Sequence, Any
import tensorflow_datasets as tfds # TFDS for MNIST
import tensorflow as tf # TensorFlow operations
tf.random.set_seed(0) # set random seed for reproducibility
num_epochs = 10
batch_size = 32
train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')
train_ds = train_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
)  # normalize train set
test_ds = test_ds.map(
    lambda sample: {
        'image': tf.cast(sample['image'], tf.float32) / 255,
        'label': sample['label'],
    }
)  # normalize test set
# create a shuffled dataset by allocating a buffer of 1024 elements to randomly draw from
train_ds = train_ds.repeat(num_epochs).shuffle(1024)
# group into batches of batch_size, drop any incomplete final batch, and prefetch the next batch to improve latency
train_ds = train_ds.batch(batch_size, drop_remainder=True).prefetch(1)
# shuffle the test set with the same buffer size
test_ds = test_ds.shuffle(1024)
# group into batches of batch_size, drop any incomplete final batch, and prefetch the next batch to improve latency
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

WARNING:absl:Variant folder /root/tensorflow_datasets/mnist/3.0.1 has no dataset_info.json
Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.
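As a quick sanity check (optional, not part of the original pipeline), one batch can be pulled from the pipeline to confirm the expected shapes:

batch = next(iter(train_ds.as_numpy_iterator()))
print(batch['image'].shape)  # expected: (32, 28, 28, 1), float32 scaled to [0, 1]
print(batch['label'].shape)  # expected: (32,)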
class down_block(nn.Module):

  @nn.compact
  def __call__(self, inputs, n_filter=32, max_pooling=True, training=True):
    conv = nn.Conv(
        n_filter,
        3,  # filter size
        padding='SAME',
        kernel_init=nn.initializers.he_normal())(inputs)
    skip_connection = conv  # skip connection taken before normalization
    conv = nn.BatchNorm(use_running_average=not training)(conv)
    conv = nn.relu(conv)
    conv = nn.Conv(
        n_filter,
        3,  # filter size
        padding='SAME',
        kernel_init=nn.initializers.he_normal())(conv)
    conv = nn.relu(conv)
    if max_pooling:
      # note: nn.max_pool defaults to stride 1, so with 'SAME' padding
      # the spatial dimensions are preserved here
      next_layer = nn.max_pool(conv, window_shape=(2, 2), padding='SAME')
    else:
      next_layer = conv
    return next_layer, skip_connection
class up_block(nn.Module):

  @nn.compact
  def __call__(self, inputs, skip_connection, filters, training=True):
    if skip_connection is None:
      x = inputs
    else:
      # concatenate along the channel axis (NHWC layout)
      x = jnp.concatenate([inputs, skip_connection], axis=3)
    x = nn.Conv(filters, 3, padding='SAME')(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    x = nn.Conv(filters, 3, padding='SAME')(x)
    x = nn.BatchNorm(use_running_average=not training)(x)
    x = nn.relu(x)
    return x
class UNet(nn.Module):

  @nn.compact
  def __call__(self, input):
    filter = [64, 128, 256]
    # encode
    x, temp1 = down_block()(input, filter[0])
    x, temp2 = down_block()(x, filter[1])
    x, _ = down_block()(x, filter[2], max_pooling=False)
    # decode
    x = up_block()(x, temp2, filter[1])
    x = up_block()(x, temp1, filter[0])
    x = up_block()(x, None, 1)  # final block maps back to a single channel
    return x
unet = UNet()
# print(
#     unet.tabulate(jax.random.key(0),
#                   jnp.ones((1, 28, 28, 1)),
#                   compute_flops=True,
#                   compute_vjp_flops=True))
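Before wiring up training, an optional shape check (not in the original) initializes the model on a dummy MNIST-sized batch; mutable=['batch_stats'] is needed because BatchNorm runs in training mode by default here:

rng = jax.random.key(0)
dummy = jnp.ones((1, 28, 28, 1))
variables = unet.init(rng, dummy)
out, _ = unet.apply(variables, dummy, mutable=['batch_stats'])
print(out.shape)  # expected: (1, 28, 28, 1), same spatial shape as the input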
@struct.dataclass
class Metrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')


class TrainState(train_state.TrainState):
  metrics: Metrics
  batch_stats: Any


def create_train_state(module, rng, learning_rate, momentum):
  """Creates an initial `TrainState`."""
  # initialize parameters and batch_stats by passing a template image
  variables = module.init(rng, jnp.ones([1, 28, 28, 1]))
  params = variables['params']
  batch_stats = variables['batch_stats']
  tx = optax.sgd(learning_rate, momentum)
  return TrainState.create(apply_fn=module.apply,
                           params=params,
                           batch_stats=batch_stats,
                           tx=tx,
                           metrics=Metrics.empty())
@jax.jit
def train_step(state, batch):
  """Train for a single step."""

  def loss_fn(params):
    predicted, updates = state.apply_fn(
        {'params': params, 'batch_stats': state.batch_stats},
        batch['image'],
        mutable=['batch_stats'],
        rngs={'dropout': jax.random.key(1)})
    # autoencoder objective: reconstruct the input image itself
    loss = optax.losses.l2_loss(predictions=predicted,
                                targets=batch['image']).mean()
    return loss, (predicted, updates)

  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (predicted, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  # keep the BatchNorm running statistics returned by the mutable apply
  state = state.replace(batch_stats=updates['batch_stats'])
  return state
@jax.jit
def compute_metrics(*, state, batch):
  predicted, updates = state.apply_fn(
      {'params': state.params, 'batch_stats': state.batch_stats},
      batch['image'],
      mutable=['batch_stats'],
      rngs={'dropout': jax.random.key(1)})
  loss = optax.losses.l2_loss(predictions=predicted,
                              targets=batch['image']).mean()
  metric_updates = state.metrics.single_from_model_output(
      predictions=predicted, targets=batch['label'], loss=loss)
  metrics = state.metrics.merge(metric_updates)
  state = state.replace(metrics=metrics)
  return state
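For reference, optax.losses.l2_loss computes the elementwise squared error scaled by 0.5, and .mean() then averages over every pixel in the batch; a tiny worked example:

p = jnp.array([1.0, 2.0])
t = jnp.array([0.0, 0.0])
print(optax.losses.l2_loss(predictions=p, targets=t))         # [0.5 2. ]  i.e. 0.5 * (p - t)**2
print(optax.losses.l2_loss(predictions=p, targets=t).mean())  # 1.25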
num_epochs = 10
batch_size = 32
# train_ds, test_ds = get_datasets(num_epochs, batch_size)
# tf.random.set_seed(0)
init_rng = jax.random.key(0)
learning_rate = 0.01
momentum = 0.9
state = create_train_state(unet, init_rng, learning_rate, momentum)
del init_rng # Must not be used anymore.
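Optionally, the freshly created state can be inspected; mapping jnp.shape over state.params (a quick sketch, not in the original) prints every parameter's shape:

print(jax.tree_util.tree_map(jnp.shape, state.params))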
# train_ds was repeated num_epochs times above, so divide by num_epochs
num_steps_per_epoch = train_ds.cardinality().numpy() // num_epochs
num_steps_per_epoch = 2  # override: use 2 steps per "epoch" to keep this run short
print(num_steps_per_epoch)
metrics_history = {
    'train_loss': [],
    'test_loss': [],
}
test_summary_writer = tf.summary.create_file_writer('test/logdir')

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  if step > 20:
    break
  # Run optimization steps over training batches and compute batch metrics
  state = train_step(state, batch)  # get updated train state (which contains the updated parameters)
  state = compute_metrics(state=state, batch=batch)  # aggregate batch metrics

  if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
    for metric, value in state.metrics.compute().items():  # compute metrics
      metrics_history[f'train_{metric}'].append(value)  # record metrics
      with test_summary_writer.as_default():
        tf.summary.scalar('train/loss', value, step=step)
    state = state.replace(metrics=state.metrics.empty())  # reset train_metrics for next training epoch

    # Compute metrics on the test set after each training epoch
    test_state = state
    for test_batch in test_ds.as_numpy_iterator():
      test_state = compute_metrics(state=test_state, batch=test_batch)
    for metric, value in test_state.metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
      with test_summary_writer.as_default():
        tf.summary.scalar('test/loss', value, step=step)

    print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['train_loss'][-1]}, ")
    print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['test_loss'][-1]}, ")
# writer.add_image('images', grid, 0)
# writer.add_graph(model, images)
2
train epoch: 1, loss: 0.23241570591926575,
test epoch: 1, loss: 0.19601216912269592,
train epoch: 2, loss: 0.1533064991235733,
test epoch: 2, loss: 0.15013296902179718,
train epoch: 3, loss: 0.12592682242393494,
test epoch: 3, loss: 0.11877958476543427,
train epoch: 4, loss: 0.0943879559636116,
test epoch: 4, loss: 0.08934623003005981,
train epoch: 5, loss: 0.06853969395160675,
test epoch: 5, loss: 0.06333110481500626,
train epoch: 6, loss: 0.04726167768239975,
test epoch: 6, loss: 0.04230469465255737,
train epoch: 7, loss: 0.02610059641301632,
test epoch: 7, loss: 0.026588963344693184,
train epoch: 8, loss: 0.01785922423005104,
test epoch: 8, loss: 0.015624734573066235,
train epoch: 9, loss: 0.01218641921877861,
test epoch: 9, loss: 0.008579996414482594,
train epoch: 10, loss: 0.006207008380442858,
test epoch: 10, loss: 0.004718703217804432,
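Beyond the TensorBoard scalars written above, the recorded metrics_history can be plotted directly; a minimal sketch, assuming matplotlib is available:

import matplotlib.pyplot as plt

plt.plot(metrics_history['train_loss'], label='train loss')
plt.plot(metrics_history['test_loss'], label='test loss')
plt.xlabel('epoch')
plt.ylabel('mean L2 loss')
plt.legend()
plt.show()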